strided BLAS DGEMM path for ToT einsum contractions #559
Open
zhihao-deng wants to merge 15 commits into
…A_STRIDED_DGEMM_COUNT option
…atch==1 assert) + tile tests
…ntEngine + e2e (nbatch>1)
…ContEngine + e2e (left-external Mo>1)
…left-external + nbatch>1) + benches
Route the regime-A hc+e einsum (outer Hadamard + outer contraction, inner outer-product) through the landed arena_strided_dgemm_ce_e core (M=N=1, K=tile volume) in run_regime_a_arena, replacing the per-cell rank-1 dger loop with one strided DGEMM per outer-contraction tile. Gated to view+double arena ToT contraction with num_contract_ranks()==0; all other kinds keep the per-cell path. Adds a regime_a_strided_disabled() kill switch, tile/e2e/differential/edge tests, and a strided-vs-per-cell benchmark (~7.3x on a C6H14-like shape).
… (either-side hce+ce)
… timing probe, bench & tests
… ranks
The einsum_tot arena-matches-owning tests iterate over all result tile ordinals but only inspect tiles local to the calling rank, then assert the per-rank elements_compared / result_outer_cells_seen counts (and the fatal BOOST_REQUIRE_GT(elements_compared, 0u)) against the global expected totals. That holds under np=1 (all tiles local) but fails under np=2: each rank sees only its share, and a rank owning no result tiles trips the REQUIRE_GT. All-reduce the accumulators (gop.sum on the counts, gop.max on max_abs_diff) before the assertions so every rank checks the true global totals. Fixes the 14 np=2 einsum_tot failures.
Summary
Lift per-cell ToT (ArenaTensor) einsum work to BLAS-3 GEMM wherever possible,
instead of looping per-cell ops.
Recast the following ArenaTensor einsum cases as strided GEMM:
ce+e (inner outer-product): ride the outer-contraction index into BLAS K.
ce+ce (inner contraction; guarded subset, not the general case): ride the outer-external index into BLAS M.
Everything outside these guarded regimes keeps the existing per-cell path, so
behavior is unchanged elsewhere.
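To illustrate what "riding the outer-contraction index into BLAS K" means, here is a minimal sketch, not the code in this PR. It compares a per-cell loop of rank-1 updates (the arithmetic dger performs) against a single GEMM-shaped loop nest in which the cell index plays the role of K and the inter-cell stride plays the role of the leading dimension. The function names `per_cell_rank1` and `strided_gemm_tn` are hypothetical, and the naive triple loop stands in for a real DGEMM call.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Per-cell path (hypothetical name): one rank-1 update per outer-contraction
// cell. Cell k of the left operand starts at a + k*lda and has m elements;
// cell k of the right operand starts at b + k*ldb and has n elements.
// c is an m x n row-major block that is accumulated in place.
void per_cell_rank1(const double* a, std::size_t lda,
                    const double* b, std::size_t ldb,
                    double* c, std::size_t m, std::size_t n,
                    std::size_t k_cells) {
  assert(lda >= m && ldb >= n);  // cells must not overlap
  for (std::size_t k = 0; k < k_cells; ++k)
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j)
        c[i * n + j] += a[k * lda + i] * b[k * ldb + j];
}

// Strided path (hypothetical name): the same sum written as one
// C += A^T * B product, where A is k_cells x m with leading dimension lda
// and B is k_cells x n with leading dimension ldb. A real implementation
// would hand these strides to a BLAS dgemm instead of looping here.
void strided_gemm_tn(const double* a, std::size_t lda,
                     const double* b, std::size_t ldb,
                     double* c, std::size_t m, std::size_t n,
                     std::size_t k_cells) {
  assert(lda >= m && ldb >= n);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      double acc = 0.0;
      for (std::size_t k = 0; k < k_cells; ++k)
        acc += a[k * lda + i] * b[k * ldb + j];
      c[i * n + j] += acc;  // beta = 1: accumulate into existing C
    }
}
```

Both functions compute the same block; the second form only works when every cell is present, has the same size, and sits at a constant stride from its neighbor, which is what the guards below check.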
Guards
A strided GEMM fires only when the cell run is "clean": all cells present,
uniform inner size, and a single constant inter-cell stride. Empty inners punch
holes that break contiguity, so in that case we fall back to segmented kernels: walk
each run and emit one strided GEMM per maximal contiguous segment of present
cells, skipping the holes and accumulating with β=1 across segments.
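The segmentation walk described above can be sketched roughly as follows. This is an illustration under assumptions, not the kernel in this PR: the `Cell` and `Segment` types and the `split_into_segments` function are hypothetical, a size of zero is assumed to mark an empty inner, and a change in inner size or inter-cell stride is assumed to end a segment just as a hole does.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical description of one inner cell in a run.
struct Cell {
  std::ptrdiff_t offset;  // element offset of the cell in the arena
  std::size_t size;       // inner size; 0 is assumed to mean "empty inner"
};

// Hypothetical description of one strided-GEMM-eligible segment.
struct Segment {
  std::size_t first;      // index of the first cell in the run
  std::size_t count;      // number of consecutive present cells
  std::ptrdiff_t stride;  // constant inter-cell stride (0 if count == 1)
};

// Walk the run once and return maximal segments of present cells that share
// one inner size and one inter-cell stride. Holes are skipped entirely.
std::vector<Segment> split_into_segments(const std::vector<Cell>& run) {
  std::vector<Segment> segments;
  std::size_t i = 0;
  while (i < run.size()) {
    if (run[i].size == 0) {  // hole: nothing to multiply
      ++i;
      continue;
    }
    Segment seg{i, 1, 0};
    for (std::size_t j = i + 1; j < run.size(); ++j) {
      if (run[j].size != run[i].size) break;  // hole or size change
      const std::ptrdiff_t d = run[j].offset - run[j - 1].offset;
      if (seg.count == 1) {
        seg.stride = d;          // first gap fixes the stride
      } else if (d != seg.stride) {
        break;                   // stride changed: close the segment
      }
      ++seg.count;
    }
    segments.push_back(seg);
    i += seg.count;
  }
  return segments;
}
```

Each returned segment would then map to one strided GEMM call with β=1, so the partial products from all segments accumulate into the same result block. A run that yields exactly one segment covering every cell is the "clean" case.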
Notes
Still carries env-gated diagnostics (TA_GEMM_TIMING, TA_STRIDED_DGEMM_VERBOSE, and the TA_STRIDED_DGEMM_COUNT build counters).
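For readers unfamiliar with the pattern, an env-gated diagnostic is typically a cheap check of an environment variable at run time. The sketch below shows one common way to write such a check. The helper name `env_flag_enabled` is hypothetical, and the rule that an unset, empty, or "0" value means "off" is an assumption, not necessarily how this PR parses TA_GEMM_TIMING or TA_STRIDED_DGEMM_VERBOSE.

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper: returns true when the named environment variable is
// set to anything other than the empty string or "0".
// Assumption: this on/off convention is illustrative only.
inline bool env_flag_enabled(const char* name) {
  const char* value = std::getenv(name);
  if (value == nullptr) return false;  // variable not set
  const std::string s(value);
  return !s.empty() && s != "0";
}
```

A call site would guard its timing or logging code with something like `if (env_flag_enabled("TA_GEMM_TIMING")) { ... }`, usually caching the result so the lookup happens once. The TA_STRIDED_DGEMM_COUNT counters are described as a build option, so they would be controlled at compile time rather than through a check like this.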